{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 8 - Introduction to Plans\n",
    "\n",
    "\n",
    "### Context \n",
    "\n",
    "Here we introduce an object that is crucial for scaling to industrial Federated Learning: the Plan. It dramatically reduces bandwidth usage, allows asynchronous schemes, and gives more autonomy to remote devices. The original concept of a plan comes from the paper [Towards Federated Learning at Scale: System Design](https://arxiv.org/pdf/1902.01046.pdf), but it has been adapted to our needs in the PySyft library.\n",
    "\n",
    "A Plan is intended to store a sequence of torch operations, just like a function, but it also lets you send this sequence of operations to remote workers and keep a reference to it. This way, to remotely compute this sequence of $n$ operations on some remote input referenced through pointers, instead of sending $n$ messages you now need to send a single message containing the references to the plan and the pointers. You can also provide tensors with your function (which we call _state tensors_) for extended functionality. Plans can be seen either as a function that you can send, or as a class that can also be sent and executed remotely. Hence, for high-level users, the notion of a plan disappears and is replaced by a magic feature that allows sending arbitrary functions containing sequential torch operations to remote workers.\n",
    "\n",
    "Note that the class of functions you can transform into plans is currently limited to sequences of hooked torch operations. In particular, this excludes logical structures like `if`, `for` and `while` statements, although we are working on workarounds. _To be completely precise, you can use these, but the logical path taken during the first computation of your plan (for example, a first `if` evaluating to False and 5 iterations of a `for` loop) will be the one kept for all subsequent computations, which we want to avoid in most cases._\n",
    "\n",
    "Authors:\n",
    "- Théo Ryffel - Twitter [@theoryffel](https://twitter.com/theoryffel) - GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
    "- Bobby Wagner - Twitter [@bobbyawagner](https://twitter.com/bobbyawagner) - GitHub: [@robert-wagner](https://github.com/robert-wagner)\n",
    "- Marianne Monteiro - Twitter [@hereismari](https://twitter.com/hereismari) - GitHub: [@mari-linhares](https://github.com/mari-linhares)"
   ]
  },
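  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the control-flow caveat concrete, here is a minimal pure-Python sketch of trace-based recording. It does not use PySyft: `ToyTracer`, `apply`, `replay` and `double_if_positive` are hypothetical names invented for this illustration, and the real Plan internals differ. The only point is that a branch taken while the operations are being recorded is baked into the recorded sequence, so replaying on a negative input still applies the doubling step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration only: ToyTracer is a hypothetical stand-in, not PySyft's Plan.\n",
    "class ToyTracer:\n",
    "    # Records the operations applied during one run, then replays them.\n",
    "    def __init__(self):\n",
    "        self.ops = []  # recorded (name, function) pairs\n",
    "\n",
    "    def apply(self, name, fn, value):\n",
    "        # Record the operation and return its result on the example input.\n",
    "        self.ops.append((name, fn))\n",
    "        return fn(value)\n",
    "\n",
    "    def replay(self, value):\n",
    "        # Re-run exactly the recorded operations; no Python control flow here.\n",
    "        for _, fn in self.ops:\n",
    "            value = fn(value)\n",
    "        return value\n",
    "\n",
    "\n",
    "def double_if_positive(tracer, v):\n",
    "    # The `if` is evaluated once, on the example input used for tracing.\n",
    "    if v > 0:\n",
    "        v = tracer.apply('double', lambda t: t * 2, v)\n",
    "    v = tracer.apply('abs', abs, v)\n",
    "    return v\n",
    "\n",
    "\n",
    "tracer = ToyTracer()\n",
    "double_if_positive(tracer, 3)  # build with a positive example input\n",
    "recorded = [name for name, _ in tracer.ops]\n",
    "replayed = tracer.replay(-3)   # the recorded 'double' step is replayed anyway\n",
    "print(recorded, replayed)"
   ]
  },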
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and model specifications\n",
    "\n",
    "First, let's import the standard PyTorch modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then those specific to PySyft, with one important note: **the local worker should not be a client worker.** *Non-client workers can store objects, and we need this ability to run a plan.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy  # import the PySyft library\n",
    "hook = sy.TorchHook(torch)  # hook PyTorch, i.e. add extra functionalities\n",
    "\n",
    "# IMPORTANT: Local worker should not be a client worker\n",
    "hook.local_worker.is_client_worker = False\n",
    "\n",
    "\n",
    "server = hook.local_worker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define remote workers or _devices_, to be consistent with the notions provided in the reference article.\n",
    "We provide them with some data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x11 = torch.tensor([-1, 2.]).tag('input_data')\n",
    "x12 = torch.tensor([1, -2.]).tag('input_data2')\n",
    "x21 = torch.tensor([-1, 2.]).tag('input_data')\n",
    "x22 = torch.tensor([1, -2.]).tag('input_data2')\n",
    "\n",
    "device_1 = sy.VirtualWorker(hook, id=\"device_1\", data=(x11, x12)) \n",
    "device_2 = sy.VirtualWorker(hook, id=\"device_2\", data=(x21, x22))\n",
    "devices = device_1, device_2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic example\n",
    "\n",
    "Let's define a function that we want to transform into a plan. To do so, it's as simple as adding a decorator above the function definition!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@sy.func2plan()\n",
    "def plan_double_abs(x):\n",
    "    x = x + x\n",
    "    x = torch.abs(x)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check: yes, we now have a plan!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plan_double_abs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use a plan, you need two things: to build the plan (_i.e. register the sequence of operations present in the function_) and to send it to a worker / device. Fortunately, you can do this very easily!\n",
    "\n",
    "#### Building a plan\n",
    "\n",
    "To build a plan you just need to call `build` on it with some data.\n",
    "\n",
    "Let's first get a reference to some remote data: a request is sent over the network and a reference pointer is returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_data = device_1.search('input_data')[0]\n",
    "pointer_to_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we try to send the plan to the device `device_1` right away... we'll get an error because the plan has not been built yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plan_double_abs.is_built"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sending a plan that has not been built will fail\n",
    "try:\n",
    "    plan_double_abs.send(device_1)\n",
    "except RuntimeError as error:\n",
    "    print(error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build a plan you just need to call `build` on the plan and pass the arguments needed to execute it (i.e. some data). When a plan is built, all the commands are executed sequentially by the local worker, captured by the plan, and stored in its `readable_plan` attribute!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plan_double_abs.build(torch.tensor([1., -2.]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plan_double_abs.is_built"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we try to send the plan now, it works!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell is executed successfully\n",
    "pointer_plan = plan_double_abs.send(device_1)\n",
    "pointer_plan"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with tensors, we get a pointer to the object sent. Here it is simply called a `PointerPlan`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One important thing to remember is that when a plan is built, we pre-set the id(s) where the result(s) should be stored ahead of computation. This makes it possible to send commands asynchronously, to already hold a reference to a virtual result, and to continue local computations without waiting for the remote result to be computed. One major application is when you request the computation of a batch on device_1 and don't want to wait for it to finish before launching another batch computation on device_2."
   ]
  },
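  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of fixing result ids ahead of computation can be sketched in plain Python. Nothing here is PySyft API: `ToyDevice`, `submit` and `run_pending` are hypothetical names, and the queue merely stands in for the network. The caller chooses the id where the result will live before any computation happens, so it already holds a usable reference while the device is still busy, and it can dispatch work to a second device without waiting for the first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "from collections import deque\n",
    "\n",
    "\n",
    "# Toy illustration only: ToyDevice is hypothetical and unrelated to sy.VirtualWorker.\n",
    "class ToyDevice:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.store = {}         # id -> value, like a worker's object store\n",
    "        self.pending = deque()  # commands received but not yet executed\n",
    "\n",
    "    def submit(self, fn, arg, result_id):\n",
    "        # Queue the command; the caller already knows where the result will be.\n",
    "        self.pending.append((fn, arg, result_id))\n",
    "\n",
    "    def run_pending(self):\n",
    "        # Execute queued commands and store each result under its pre-set id.\n",
    "        while self.pending:\n",
    "            fn, arg, result_id = self.pending.popleft()\n",
    "            self.store[result_id] = fn(arg)\n",
    "\n",
    "\n",
    "toy_a, toy_b = ToyDevice('toy_a'), ToyDevice('toy_b')\n",
    "\n",
    "# Pre-set the result ids, then dispatch both batches without waiting.\n",
    "id_a, id_b = str(uuid.uuid4()), str(uuid.uuid4())\n",
    "toy_a.submit(sum, [1, 2, 3], id_a)\n",
    "toy_b.submit(sum, [4, 5, 6], id_b)\n",
    "\n",
    "# The references exist even though nothing has been computed yet.\n",
    "computed_before = id_a in toy_a.store\n",
    "\n",
    "toy_a.run_pending()\n",
    "toy_b.run_pending()\n",
    "results = (toy_a.store[id_a], toy_b.store[id_b])\n",
    "print(computed_before, results)"
   ]
  },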
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running a Plan remotely\n",
    "\n",
    "We can now remotely run the plan by calling the pointer to the plan with a pointer to some data. This issues a command to run the plan remotely, so that the predefined location of the plan's output now contains the result (remember that we pre-set the location of the result ahead of computation). This also requires a single communication round.\n",
    "\n",
    "The result is simply a pointer, just like when you call a usual hooked torch function!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_result = pointer_plan(pointer_to_data)\n",
    "print(pointer_to_result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And you can simply ask for the value back."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_result.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Towards a concrete example\n",
    "\n",
    "But what we really want is to apply Plans to deep and federated learning, right? So let's look at a slightly more complicated example, using neural networks as you might want to use them.\n",
    "Note that we are now transforming a class into a plan. To do so, our class inherits from `sy.Plan` (instead of inheriting from `nn.Module`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(sy.Plan):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(2, 3)\n",
    "        self.fc2 = nn.Linear(3, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's build the plan using some mock data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.build(torch.tensor([1., 2.]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now send the plan to a remote worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_net = net.send(device_1)\n",
    "pointer_to_net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's retrieve some remote data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_data = device_1.search('input_data')[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, the syntax is just like normal remote sequential execution, that is, just like local execution. But compared to classic remote execution, there is a single communication round for each execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_result = pointer_to_net(pointer_to_data)\n",
    "pointer_to_result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we get the result as usual!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_result.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Et voilà! We have seen how to dramatically reduce the communication between the local worker (or server) and the remote devices!"
   ]
  },
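  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the saving, the sketch below counts messages for a toy channel. `ToyChannel` and `send_message` are invented for this example and are not part of PySyft; each call to `send_message` stands for one communication round. Running two operations one by one costs two rounds, whereas calling a plan that already lives on the device costs a single round carrying the plan reference and the data pointer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration only: ToyChannel is hypothetical and just counts messages.\n",
    "class ToyChannel:\n",
    "    def __init__(self):\n",
    "        self.messages = 0\n",
    "\n",
    "    def send_message(self, payload):\n",
    "        # One call models one communication round; the payload is echoed back.\n",
    "        self.messages += 1\n",
    "        return payload\n",
    "\n",
    "\n",
    "# The two operations of a double-then-abs function.\n",
    "toy_ops = [lambda v: v + v, abs]\n",
    "\n",
    "# Without a plan: one message per operation.\n",
    "op_by_op = ToyChannel()\n",
    "toy_value = -3\n",
    "for op in toy_ops:\n",
    "    toy_value = op_by_op.send_message(op)(toy_value)\n",
    "\n",
    "# With a plan already on the device: a single message holding two references.\n",
    "with_plan = ToyChannel()\n",
    "plan_ref, data_ref = with_plan.send_message(('plan_id', 'data_id'))\n",
    "\n",
    "print(toy_value, op_by_op.messages, with_plan.messages)"
   ]
  },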
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Switch between workers\n",
    "\n",
    "One major feature we want is the ability to use the same plan on several workers, switching between them depending on the remote batch of data we are considering.\n",
    "In particular, we don't want to rebuild the plan each time we change workers. Let's see how to do this, using the previous example with our small network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(sy.Plan):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(2, 3)\n",
    "        self.fc2 = nn.Linear(3, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Net()\n",
    "\n",
    "# Build plan\n",
    "net.build(torch.tensor([1., 2.]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are the main steps we just executed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_net_1 = net.send(device_1)\n",
    "pointer_to_data = device_1.search('input_data')[0]\n",
    "pointer_to_result = pointer_to_net_1(pointer_to_data)\n",
    "pointer_to_result.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In fact, you can create other PointerPlans from the same plan, so the syntax for running a plan remotely on another device is the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_to_net_2 = net.send(device_2)\n",
    "pointer_to_data = device_2.search('input_data')[0]\n",
    "pointer_to_result = pointer_to_net_2(pointer_to_data)\n",
    "pointer_to_result.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: Currently, with Plan classes, you can only use a single method and you have to name it \"forward\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automatically building plans that are functions\n",
    "\n",
    "For functions (`@sy.func2plan`) we can build the plan automatically, with no need to explicitly call `build`: the plan is already built at the moment of creation.\n",
    "\n",
    "To get this functionality, the only thing you need to change when creating a plan is to pass the decorator an argument called `args_shape`, which should be a list containing the shape of each argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@sy.func2plan(args_shape=[(-1, 1)])\n",
    "def plan_double_abs(x):\n",
    "    x = x + x\n",
    "    x = torch.abs(x)\n",
    "    return x\n",
    "\n",
    "plan_double_abs.is_built"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `args_shape` parameter is used internally to create mock tensors with the given shapes, which are then used to build the plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@sy.func2plan(args_shape=[(1, 2), (-1, 2)])\n",
    "def plan_sum_abs(x, y):\n",
    "    s = x + y\n",
    "    return torch.abs(s)\n",
    "\n",
    "plan_sum_abs.is_built"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also provide state elements to functions!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@sy.func2plan(args_shape=[(1,)], state=(torch.tensor([1]), ))\n",
    "def plan_abs(x, state):\n",
    "    bias, = state.read()\n",
    "    x = x.abs()\n",
    "    return x + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pointer_plan = plan_abs.send(device_1)\n",
    "x_ptr = torch.tensor([-1, 0]).send(device_1)\n",
    "p = pointer_plan(x_ptr)\n",
    "p.get()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To learn more about this, you can discover how we use Plans with Protocols in Tutorial Part 8 bis!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
